import copy

# import networkx as nx
import matplotlib.pyplot as plt

from ModularUtils.FunctionsConstant import build_compares, draw_true_graph, getdoKey
from ModularUtils.ControllerConstants import generate_permutations


def set_sachs_nonId_graph(noise_states, latent_state, obs_state, Data_intervs):
    """Build the configuration for the Sachs protein-signalling graph with a latent confounder.

    Observed structure: PKA -> Mek, {PKA, Mek} -> Erk, {PKA, Erk} -> Akt.
    A single latent confounder ``U0`` is a parent of both ``PKA`` and ``Mek``.
    The graph name says "nonId"; presumably that refers to non-identifiability of
    queries such as P(Mek | do(PKA)) under this confounding. Confirm against the paper/caller.

    Args:
        noise_states: Value stored in ``label_dim`` for every noise variable ``n<label>``.
            It is also the second element of that variable's ``noise_params`` entry.
        latent_state: Value stored in ``label_dim`` for each latent confounder.
            It is also the second element of that confounder's ``noise_params`` entry.
        obs_state: Value stored in ``label_dim`` for each observed variable.
        Data_intervs: Unused here; kept so the signature matches callers.

    Returns:
        A 24-tuple, in this order:
            1. DAG_desc
            2. Complete_DAG_desc
            3. Complete_DAG
            4. complete_labels
            5. Observed_DAG
            6. label_names
            7. image_labels
            8. rep_labels
            9. interv_queries
            10. cf_queries
            11. latent_conf
            12. confTochild
            13. exogenous
            14. cf_intervene
            15. cf_observe
            16. cf_evidence
            17. cflabel_names
            18. twin_map
            19. Twin_Network
            20. cf_exogenous
            21. noise_params
            22. train_mech_dict
            23. label_dim
            24. plot_title

        The counterfactual settings (cf_intervene, cf_observe, cf_evidence, twin_map,
        cf_exogenous) are None, and cf_queries is empty.

    Side effects:
        Calls ``draw_true_graph(Complete_DAG)`` and prints ``train_mech_dict``
        for every observed label.
    """
    DAG_desc = "sachs_nonId_graph"
    Complete_DAG_desc = "sachs_nonId_graph"
    plot_title = "Sachs/Protein Signal dataset"
    image_labels = []
    rep_labels = []
    cf_queries = []

    # Observed causal structure: variable -> list of its observed parents.
    Observed_DAG = {
        "PKA": [],
        "Mek": ["PKA"],
        "Erk": ["PKA", "Mek"],
        "Akt": ["PKA", "Erk"],
    }
    label_names = list(Observed_DAG)

    # Latent confounder -> observed variables it points into.
    # This is the single source of truth for which confounder nodes exist.
    confTochild = {"U0": ["PKA", "Mek"]}

    # latent_conf: observed variable -> confounders that are its parents.
    latent_conf = {var: [] for var in Observed_DAG}
    for conf, children in confTochild.items():
        for var in children:
            latent_conf[var].append(conf)

    # Complete DAG. Confounders come first, with no parents. Each observed variable then
    # lists its confounder parents followed by its observed parents. Concatenation builds
    # new lists, so entries are not aliased to Observed_DAG or latent_conf.
    Complete_DAG = {conf: [] for conf in confTochild}
    for var, parents in Observed_DAG.items():
        Complete_DAG[var] = latent_conf[var] + parents
    complete_labels = list(Complete_DAG)

    draw_true_graph(Complete_DAG)

    # Per-variable setting, presumably the state count / dimension (TODO confirm).
    # Order: observed variables, then confounders, then noise variables.
    label_dim = {label: obs_state for label in Observed_DAG}
    label_dim.update({conf: latent_state for conf in confTochild})
    label_dim.update({"n" + label: noise_states for label in Observed_DAG})

    # Interventional query specs; "expr" is rendered by getdoKey.
    # NOTE(review): intervention_list is not used further in this function. An alternative
    # is to build interv_queries from it by enumerating every value of "inter_vars" with
    # generate_permutations.
    intervention_list = [
        {"expr": "P(Mek|do[PKA])", "obs": ["Mek"], "inter_vars": ["PKA"]},
        {"expr": "P(Erk|do[PKA])", "obs": ["Erk"], "inter_vars": ["PKA"]},
        {"expr": "P(Akt|do[PKA])", "obs": ["Akt"], "inter_vars": ["PKA"]},
    ]
    for intervention in intervention_list:
        intervention["expr"] = getdoKey(intervention["obs"], intervention["inter_vars"])

    # Queries evaluated by the caller. "intervs": [{}] means purely observational.
    # The value PKA=2 assumes obs_state allows at least 3 states (TODO confirm).
    interv_queries = [
        {"obs": ["PKA", "Mek"], "intervs": [{}], "expr": "P(PKA,Mek)"},
        {"obs": ["Erk", "Akt"], "intervs": [{}], "expr": "P(Erk,Akt)"},
        {"obs": ["PKA", "Mek", "Erk", "Akt"], "intervs": [{}], "expr": "P(V)"},
        {"obs": ["Mek"], "intervs": [{"PKA": 2}], "expr": "P(Mek|do[PKA=2])"},
        {"obs": ["Erk"], "intervs": [{"PKA": 2}], "expr": "P(Erk|do[PKA=2])"},
        {"obs": ["Akt"], "intervs": [{"PKA": 2}], "expr": "P(Akt|do[PKA=2])"},
    ]

    # Each observed variable gets its own exogenous noise variable "n<label>".
    exogenous = {label: "n" + label for label in label_names}

    # Counterfactual (twin-network) setup. Only the twin graph and its labels are defined.
    cflabel_names = complete_labels + ["PKCp"]
    Twin_Network = copy.deepcopy(Complete_DAG)
    Twin_Network["PKCp"] = []
    # Mek's parents are cut in the twin graph; presumably this models an intervention on Mek.
    Twin_Network["Mek"] = []

    # Counterfactual settings are disabled for this graph.
    # cf_exogenous used to be set to None and then had "Mek" deleted from it, which raised
    # TypeError on every call. It stays None, like the other disabled cf_* settings.
    cf_exogenous = None
    cf_observe = None
    cf_intervene = None
    cf_evidence = None
    twin_map = None

    # (first value, state setting) per noise variable and confounder.
    # The meaning of 0.5 / 0.1 is defined by the consumer (TODO confirm).
    noise_params = {"n" + label: (0.5, noise_states) for label in Observed_DAG}
    noise_params.update({conf: (0.1, latent_state) for conf in confTochild})

    # compare: variables whose joint is needed.
    # parents: variables that need to be intervened on.
    train_mech_dict = {
        "PKA": [
            {'parents': [], 'intv': {}, 'compare': ['PKA', 'Mek']},
            {'parents': ['PKA'], 'intv': {}, 'compare': []},
        ],
        "Mek": [
            {'parents': [], 'intv': {}, 'compare': ['PKA', 'Mek']},
            {'parents': ['PKA'], 'intv': {}, 'compare': ['Mek']},
        ],
        "Erk": [
            {'parents': ['PKA', 'Mek'], 'intv': {}, 'compare': ['Erk']},
            {'parents': ['PKA', 'Mek'], 'intv': {}, 'compare': ['Erk']},
        ],
        "Akt": [
            {'parents': ['PKA', 'Erk'], 'intv': {}, 'compare': ['Akt']},
            {'parents': ['PKA', 'Erk'], 'intv': {}, 'compare': ['Akt']},
        ],
    }

    print("printing")
    for label in label_names:
        print(label, train_mech_dict[label])

    return DAG_desc, Complete_DAG_desc, Complete_DAG, complete_labels, Observed_DAG, label_names, image_labels, rep_labels, interv_queries, cf_queries, latent_conf, \
           confTochild, exogenous, cf_intervene, cf_observe, cf_evidence, cflabel_names, twin_map, Twin_Network, cf_exogenous, \
           noise_params, train_mech_dict, label_dim, plot_title


